# -*- coding:utf-8 -*-
import baostock as bs
import pandas as pd
from baostock.data.resultset import ResultData

from dao import base_dao
from datasource import session_factory
from model.stock_model import Stock
from system.sys_logger import Logfactory

from datetime import datetime

from utils import uuid_util

"""
方法说明：获取指定交易日期所有股票列表。通过API接口获取证券代码及股票交易状态信息，与日K线数据同时更新。可以通过参数‘某交易日’获取数据（包括：A股、指数），数据范围同接口query_history_k_data_plus()。
返回类型：pandas的DataFrame类型。
更新时间：与日K线同时更新。
"""


def get_all_stock(trade_date):
    if (trade_date is None) or (len(trade_date) == 0):
        trade_date = datetime.datetime.now().strftime('%Y-%m-%d')

    data_list = []
    rs = bs.query_all_stock(day=trade_date)

    if rs.error_code != '0':
        Logfactory.logger.error(
            "get_all_stock failed . and respond error_code=%s "
            "error_msg=%s", rs.error_code, rs.error_msg)
    else:
        Logfactory.logger.info(
            "get_all_stock success . and respond error_code=%s "
            "error_msg=%s", rs.error_code, rs.error_msg)

        while rs.next():
            # 获取一条记录，将记录合并在一起
            data_list.append(rs.get_row_data())
        result = pd.DataFrame(data_list, columns=rs.fields)

        stock_list = build_stock_list(result)

        base_dao.save_all(stock_list)
        Logfactory.logger.info("save stock succes")


def build_stock_list(stock_frame: pd.DataFrame):
    """Turn a ``query_all_stock`` DataFrame into a list of ``Stock`` models.

    :param stock_frame: frame whose rows are passed one by one to
        ``build_stock``. ``None`` or a frame with zero rows yields ``[]``.
    :return: a new list holding every non-``None`` result of ``build_stock``,
        in row order.
    """
    if stock_frame is None or len(stock_frame) == 0:
        return []

    # Build lazily, then drop rows that ``build_stock`` rejected (``None``).
    candidates = (build_stock(record) for _, record in stock_frame.iterrows())
    return [candidate for candidate in candidates if candidate is not None]


def build_stock(row: pd.Series):
    """Convert one ``query_all_stock`` result row into a ``Stock`` model.

    :param row: a row exposing ``code``, ``code_name`` and ``tradeStatus``
        entries, e.g. one yielded by ``DataFrame.iterrows()``.
    :return: a new ``Stock`` with a freshly generated id, or ``None`` (after
        logging a warning) when ``row`` is ``None``.
    """
    if row is None:
        # ``Logger.warn`` is a deprecated alias of ``Logger.warning``.
        # NOTE(review): assumes Logfactory.logger is a std logging.Logger
        # (it is used with .error/.info/.exception elsewhere) — confirm.
        Logfactory.logger.warning("stock is None")
        return None

    stock_code = row["code"]
    stock_name = row["code_name"]
    trade_status = row["tradeStatus"]

    return Stock(id=uuid_util.generate_uuid(), stock_code=stock_code,
                 stock_name=stock_name, trade_status=trade_status)


"""
根据 证券编码 查询证券
如果证券编码为空，则默认查询所有证券列表
"""


def select_all_stock(stock_code):
    stock_list = []
    session = session_factory.get_session()
    try:
        query = None
        if (stock_code is None) or len(stock_code) == 0:
            query = session.query(Stock)
        else:
            query = session.query(Stock).filter(Stock.stock_code == stock_code)

        if query is not None:
            # 遍历所有用户
            for stock in query.all():
                stock_list.append(stock)
    except Exception as e:
        Logfactory.logger.exception("select_all_stock exception and error message is %s", e)
        raise e
    finally:
        session_factory.close_session(session)
        return stock_list
